{
 "cells": [
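  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XGBoost Parameter Tuning for Rental Listing Inquiries\n",
    "Rental Listing Inquiries is a classification competition on Kaggle. The task is to predict how popular an apartment listing is from its features, with user interest split into three classes: high, medium, and low. The listing features x have 14 dimensions in the original data (the feature-engineered files loaded below contain 227 feature columns plus the label), and the response y is the level of user interest in the apartment. The evaluation metric is logloss. Data: https://www.kaggle.com/c/two-sigma-connect-rental-listing-inquiries"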
   ]
  },
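  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General steps for tuning XGBoost\n",
    "<br/>1. Start with a relatively large <font color=#D02090>learning rate</font> (0.1 is a common choice), then use cross-validation to pick a suitable <font color=#D02090>number of trees (n_estimators)</font>.\n",
    "<br/>2. With the learning rate and the number of trees fixed, tune the <font color=#D02090>tree parameters</font> (max_depth,&nbsp; min_child_weight,&nbsp; gamma,&nbsp; subsample,&nbsp; colsample_bytree,&nbsp; colsample_bylevel).\n",
    "<br/>3. Tune the XGBoost <font color=#D02090>regularization parameters (lambda, alpha)</font>, exposed as reg_lambda and reg_alpha in the scikit-learn wrapper.\n",
    "<br/>4. <font color=#D02090>Lower the learning rate and re-determine the number of trees</font>, then go back to step 1 and repeat the cycle."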
   ]
  },
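  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 2 of the list above can be sketched as follows. This is a minimal illustration under stated assumptions, not the tuning code used for the rental data: it runs on a small synthetic 3-class dataset from `make_classification`, every `demo_*` name is hypothetical, and the grid values are arbitrary. It only shows the pattern of holding the learning rate and tree count fixed while `GridSearchCV` compares tree parameters by cross-validated log loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch on synthetic data; all demo_* names are illustrative only.\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import GridSearchCV, StratifiedKFold\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "# Small synthetic 3-class problem standing in for the rental data.\n",
    "demo_X, demo_y = make_classification(\n",
    "    n_samples=600, n_features=20, n_informative=8,\n",
    "    n_classes=3, random_state=0)\n",
    "\n",
    "# Step 1 is assumed done: learning rate and tree count are held fixed.\n",
    "demo_clf = XGBClassifier(learning_rate=0.1, n_estimators=50,\n",
    "                         subsample=0.8, colsample_bytree=0.8)\n",
    "\n",
    "# Step 2: compare tree-structure parameters by stratified CV log loss.\n",
    "demo_grid = {'max_depth': [3, 5], 'min_child_weight': [1, 3]}\n",
    "demo_folds = StratifiedKFold(n_splits=3, shuffle=True, random_state=3)\n",
    "demo_search = GridSearchCV(demo_clf, demo_grid,\n",
    "                           scoring='neg_log_loss', cv=demo_folds)\n",
    "demo_search.fit(demo_X, demo_y)\n",
    "\n",
    "print(demo_search.best_params_)\n",
    "print(-demo_search.best_score_)  # mean cross-validated log loss of the best setting"
   ]
  },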
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "import xgboost as xgb\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "from sklearn.metrics import log_loss\n",
    "\n",
    "from matplotlib import pyplot\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>bathrooms</th>\n",
       "      <th>bedrooms</th>\n",
       "      <th>price</th>\n",
       "      <th>price_bathrooms</th>\n",
       "      <th>price_bedrooms</th>\n",
       "      <th>room_diff</th>\n",
       "      <th>room_num</th>\n",
       "      <th>Year</th>\n",
       "      <th>Month</th>\n",
       "      <th>Day</th>\n",
       "      <th>...</th>\n",
       "      <th>walk</th>\n",
       "      <th>walls</th>\n",
       "      <th>war</th>\n",
       "      <th>washer</th>\n",
       "      <th>water</th>\n",
       "      <th>wheelchair</th>\n",
       "      <th>wifi</th>\n",
       "      <th>windows</th>\n",
       "      <th>work</th>\n",
       "      <th>interest_level</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.5</td>\n",
       "      <td>3</td>\n",
       "      <td>3000</td>\n",
       "      <td>1200.0</td>\n",
       "      <td>750.000000</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>4.5</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>5465</td>\n",
       "      <td>2732.5</td>\n",
       "      <td>1821.666667</td>\n",
       "      <td>-1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>12</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>2850</td>\n",
       "      <td>1425.0</td>\n",
       "      <td>1425.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>4</td>\n",
       "      <td>17</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>3275</td>\n",
       "      <td>1637.5</td>\n",
       "      <td>1637.500000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>4</td>\n",
       "      <td>18</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>4</td>\n",
       "      <td>3350</td>\n",
       "      <td>1675.0</td>\n",
       "      <td>670.000000</td>\n",
       "      <td>-3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>4</td>\n",
       "      <td>28</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 228 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   bathrooms  bedrooms  price  price_bathrooms  price_bedrooms  room_diff  \\\n",
       "0        1.5         3   3000           1200.0      750.000000       -1.5   \n",
       "1        1.0         2   5465           2732.5     1821.666667       -1.0   \n",
       "2        1.0         1   2850           1425.0     1425.000000        0.0   \n",
       "3        1.0         1   3275           1637.5     1637.500000        0.0   \n",
       "4        1.0         4   3350           1675.0      670.000000       -3.0   \n",
       "\n",
       "   room_num  Year  Month  Day       ...        walk  walls  war  washer  \\\n",
       "0       4.5  2016      6   24       ...           0      0    0       0   \n",
       "1       3.0  2016      6   12       ...           0      0    0       0   \n",
       "2       2.0  2016      4   17       ...           0      0    0       0   \n",
       "3       2.0  2016      4   18       ...           0      0    0       0   \n",
       "4       5.0  2016      4   28       ...           0      0    1       0   \n",
       "\n",
       "   water  wheelchair  wifi  windows  work  interest_level  \n",
       "0      0           0     0        0     0               1  \n",
       "1      0           0     0        0     0               2  \n",
       "2      0           0     0        0     0               0  \n",
       "3      0           0     0        0     0               2  \n",
       "4      0           0     0        0     0               2  \n",
       "\n",
       "[5 rows x 228 columns]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dpath = './data/' \n",
    "train = pd.read_csv(dpath + \"RentListingInquries_FE_train.csv\")\n",
    "test = pd.read_csv(dpath + \"RentListingInquries_FE_test.csv\")\n",
    "\n",
    "train.head()\n",
    "#test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>bathrooms</th>\n",
       "      <th>bedrooms</th>\n",
       "      <th>price</th>\n",
       "      <th>price_bathrooms</th>\n",
       "      <th>price_bedrooms</th>\n",
       "      <th>room_diff</th>\n",
       "      <th>room_num</th>\n",
       "      <th>Year</th>\n",
       "      <th>Month</th>\n",
       "      <th>Day</th>\n",
       "      <th>...</th>\n",
       "      <th>walk</th>\n",
       "      <th>walls</th>\n",
       "      <th>war</th>\n",
       "      <th>washer</th>\n",
       "      <th>water</th>\n",
       "      <th>wheelchair</th>\n",
       "      <th>wifi</th>\n",
       "      <th>windows</th>\n",
       "      <th>work</th>\n",
       "      <th>interest_level</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>49352.00000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>4.935200e+04</td>\n",
       "      <td>4.935200e+04</td>\n",
       "      <td>4.935200e+04</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.0</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "      <td>49352.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>1.21218</td>\n",
       "      <td>1.541640</td>\n",
       "      <td>3.830174e+03</td>\n",
       "      <td>1.697863e+03</td>\n",
       "      <td>1.657567e+03</td>\n",
       "      <td>-0.329460</td>\n",
       "      <td>2.753820</td>\n",
       "      <td>2016.0</td>\n",
       "      <td>5.014852</td>\n",
       "      <td>15.206881</td>\n",
       "      <td>...</td>\n",
       "      <td>0.003080</td>\n",
       "      <td>0.000385</td>\n",
       "      <td>0.186477</td>\n",
       "      <td>0.009361</td>\n",
       "      <td>0.000446</td>\n",
       "      <td>0.028165</td>\n",
       "      <td>0.002026</td>\n",
       "      <td>0.001013</td>\n",
       "      <td>0.000952</td>\n",
       "      <td>1.616895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.50142</td>\n",
       "      <td>1.115018</td>\n",
       "      <td>2.206687e+04</td>\n",
       "      <td>1.100477e+04</td>\n",
       "      <td>7.817996e+03</td>\n",
       "      <td>0.947732</td>\n",
       "      <td>1.446091</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.824442</td>\n",
       "      <td>8.280749</td>\n",
       "      <td>...</td>\n",
       "      <td>0.055412</td>\n",
       "      <td>0.019618</td>\n",
       "      <td>0.389495</td>\n",
       "      <td>0.101625</td>\n",
       "      <td>0.021109</td>\n",
       "      <td>0.165446</td>\n",
       "      <td>0.044969</td>\n",
       "      <td>0.031814</td>\n",
       "      <td>0.030846</td>\n",
       "      <td>0.626035</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.300000e+01</td>\n",
       "      <td>2.150000e+01</td>\n",
       "      <td>4.300000e+01</td>\n",
       "      <td>-5.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2016.0</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>1.00000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.500000e+03</td>\n",
       "      <td>1.225000e+03</td>\n",
       "      <td>1.066667e+03</td>\n",
       "      <td>-1.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>2016.0</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>1.00000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.150000e+03</td>\n",
       "      <td>1.500000e+03</td>\n",
       "      <td>1.383417e+03</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>2016.0</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>15.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.00000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>4.100000e+03</td>\n",
       "      <td>1.850000e+03</td>\n",
       "      <td>1.962500e+03</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>2016.0</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>22.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>10.00000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>4.490000e+06</td>\n",
       "      <td>2.245000e+06</td>\n",
       "      <td>1.496667e+06</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>13.500000</td>\n",
       "      <td>2016.0</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>31.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>8 rows × 228 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         bathrooms      bedrooms         price  price_bathrooms  \\\n",
       "count  49352.00000  49352.000000  4.935200e+04     4.935200e+04   \n",
       "mean       1.21218      1.541640  3.830174e+03     1.697863e+03   \n",
       "std        0.50142      1.115018  2.206687e+04     1.100477e+04   \n",
       "min        0.00000      0.000000  4.300000e+01     2.150000e+01   \n",
       "25%        1.00000      1.000000  2.500000e+03     1.225000e+03   \n",
       "50%        1.00000      1.000000  3.150000e+03     1.500000e+03   \n",
       "75%        1.00000      2.000000  4.100000e+03     1.850000e+03   \n",
       "max       10.00000      8.000000  4.490000e+06     2.245000e+06   \n",
       "\n",
       "       price_bedrooms     room_diff      room_num     Year         Month  \\\n",
       "count    4.935200e+04  49352.000000  49352.000000  49352.0  49352.000000   \n",
       "mean     1.657567e+03     -0.329460      2.753820   2016.0      5.014852   \n",
       "std      7.817996e+03      0.947732      1.446091      0.0      0.824442   \n",
       "min      4.300000e+01     -5.000000      0.000000   2016.0      4.000000   \n",
       "25%      1.066667e+03     -1.000000      2.000000   2016.0      4.000000   \n",
       "50%      1.383417e+03      0.000000      2.000000   2016.0      5.000000   \n",
       "75%      1.962500e+03      0.000000      4.000000   2016.0      6.000000   \n",
       "max      1.496667e+06      8.000000     13.500000   2016.0      6.000000   \n",
       "\n",
       "                Day       ...                walk         walls           war  \\\n",
       "count  49352.000000       ...        49352.000000  49352.000000  49352.000000   \n",
       "mean      15.206881       ...            0.003080      0.000385      0.186477   \n",
       "std        8.280749       ...            0.055412      0.019618      0.389495   \n",
       "min        1.000000       ...            0.000000      0.000000      0.000000   \n",
       "25%        8.000000       ...            0.000000      0.000000      0.000000   \n",
       "50%       15.000000       ...            0.000000      0.000000      0.000000   \n",
       "75%       22.000000       ...            0.000000      0.000000      0.000000   \n",
       "max       31.000000       ...            1.000000      1.000000      1.000000   \n",
       "\n",
       "             washer         water    wheelchair          wifi       windows  \\\n",
       "count  49352.000000  49352.000000  49352.000000  49352.000000  49352.000000   \n",
       "mean       0.009361      0.000446      0.028165      0.002026      0.001013   \n",
       "std        0.101625      0.021109      0.165446      0.044969      0.031814   \n",
       "min        0.000000      0.000000      0.000000      0.000000      0.000000   \n",
       "25%        0.000000      0.000000      0.000000      0.000000      0.000000   \n",
       "50%        0.000000      0.000000      0.000000      0.000000      0.000000   \n",
       "75%        0.000000      0.000000      0.000000      0.000000      0.000000   \n",
       "max        2.000000      1.000000      1.000000      1.000000      1.000000   \n",
       "\n",
       "               work  interest_level  \n",
       "count  49352.000000    49352.000000  \n",
       "mean       0.000952        1.616895  \n",
       "std        0.030846        0.626035  \n",
       "min        0.000000        0.000000  \n",
       "25%        0.000000        1.000000  \n",
       "50%        0.000000        2.000000  \n",
       "75%        0.000000        2.000000  \n",
       "max        1.000000        2.000000  \n",
       "\n",
       "[8 rows x 228 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 49352 entries, 0 to 49351\n",
      "Columns: 228 entries, bathrooms to interest_level\n",
      "dtypes: float64(9), int64(219)\n",
      "memory usage: 85.8 MB\n"
     ]
    }
   ],
   "source": [
    "train.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_Y = train['interest_level']\n",
    "train_X = train.drop([\"interest_level\"], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training set is large, so hold out part of it with train_test_split to estimate model performance\n",
    "from sklearn.model_selection import train_test_split\n",
    "# Pass test_size explicitly so it complements train_size (this avoids the scikit-learn FutureWarning)\n",
    "train_x, X_val, train_y, y_val = train_test_split(train_X, train_Y, train_size=0.5, test_size=0.5, random_state=0)"
   ]
  },
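  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The summary statistics above suggest that the three interest levels are not equally frequent (the median of `interest_level` is 2). One optional variation, shown here only as a sketch and not used elsewhere in this notebook, is to pass `stratify=` to `train_test_split` so that both halves keep the same class proportions. The toy labels and the `demo_*` and `y_half_*` names below are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch on toy labels; not the rental data.\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Imbalanced toy labels: 100 of class 0, 200 of class 1, 700 of class 2.\n",
    "demo_labels = np.repeat([0, 1, 2], [100, 200, 700])\n",
    "demo_features = np.arange(demo_labels.size).reshape(-1, 1)\n",
    "\n",
    "# stratify=... asks for the same class proportions in both halves.\n",
    "_, _, y_half_a, y_half_b = train_test_split(\n",
    "    demo_features, demo_labels, train_size=0.5, test_size=0.5,\n",
    "    random_state=0, stratify=demo_labels)\n",
    "\n",
    "# Class frequencies in each half.\n",
    "print(np.bincount(y_half_a) / y_half_a.size)\n",
    "print(np.bincount(y_half_b) / y_half_b.size)"
   ]
  },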
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Fix the Learning Rate and Determine the Number of Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the XGBoost parameters (scikit-learn wrapper)\n",
    "xgb1 = XGBClassifier(\n",
    "        learning_rate=0.1,\n",
    "        n_estimators=1000,  # an upper bound; a large value is fine because cv with early stopping returns a suitable n_estimators\n",
    "        max_depth=5,\n",
    "        min_child_weight=1,\n",
    "        gamma=0,\n",
    "        subsample=0.6,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective='multi:softprob',\n",
    "        seed=3\n",
    "        )\n",
    "\n",
    "# Call XGBoost's built-in cross-validation (xgb.cv) directly: it evaluates every\n",
    "# consecutive value of n_estimators in a single pass,\n",
    "# whereas GridSearchCV can only cross-validate a finite list of candidate values\n",
    "def modelfit(alg, x_train, y_train, cv_folds=5, early_stopping_rounds=10):\n",
    "\n",
    "    xgb_param = alg.get_xgb_params()\n",
    "    xgb_param['num_class'] = 3\n",
    "\n",
    "    # Use the native xgboost API here rather than the scikit-learn wrapper class\n",
    "    xgtrain = xgb.DMatrix(x_train, label=y_train)\n",
    "\n",
    "    # get_params() is the scikit-learn accessor and always includes n_estimators\n",
    "    cvresult = xgb.cv(xgb_param, xgtrain, num_boost_round=alg.get_params()['n_estimators'], nfold=cv_folds,\n",
    "                     metrics='mlogloss', early_stopping_rounds=early_stopping_rounds)\n",
    "\n",
    "    cvresult.to_csv('1_nestimators.csv', index_label='n_estimators')\n",
    "\n",
    "    # Best n_estimators: with early stopping, cvresult has one row per boosting round kept\n",
    "    n_estimators = cvresult.shape[0]\n",
    "\n",
    "    # Train the model with the n_estimators found by cross-validation\n",
    "    # (eval_metric is omitted: it only matters when an eval_set is supplied)\n",
    "    alg.set_params(n_estimators=n_estimators)\n",
    "    alg.fit(x_train, y_train)\n",
    "\n",
    "modelfit(xgb1, train_x, train_y)"
   ]
  },
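  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `cvresult.shape[0]` step in `modelfit` more concrete, the sketch below runs `xgb.cv` with early stopping on a small synthetic dataset. It is an illustration only: the data, the parameter values and the `cv_demo_*` names are assumptions, not part of the rental pipeline. When early stopping triggers, the returned DataFrame is cut off at the best round, so its row count can be read as the number of trees to keep."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch on synthetic data; all cv_demo_* names are illustrative only.\n",
    "import xgboost as xgb\n",
    "from sklearn.datasets import make_classification\n",
    "\n",
    "cv_demo_X, cv_demo_y = make_classification(\n",
    "    n_samples=600, n_features=20, n_informative=8,\n",
    "    n_classes=3, random_state=0)\n",
    "cv_demo_dtrain = xgb.DMatrix(cv_demo_X, label=cv_demo_y)\n",
    "\n",
    "cv_demo_params = {'objective': 'multi:softprob', 'num_class': 3,\n",
    "                  'eta': 0.1, 'max_depth': 5}\n",
    "\n",
    "# Up to 200 rounds; stop once the held-out mlogloss has not improved for 10 rounds.\n",
    "cv_demo_result = xgb.cv(cv_demo_params, cv_demo_dtrain, num_boost_round=200,\n",
    "                        nfold=3, stratified=True, metrics='mlogloss',\n",
    "                        early_stopping_rounds=10, seed=3)\n",
    "\n",
    "# One row per boosting round that was kept.\n",
    "print(cv_demo_result.shape[0])\n",
    "# Mean held-out mlogloss at the last kept round.\n",
    "print(cv_demo_result['test-mlogloss-mean'].iloc[-1])"
   ]
  },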
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'base_score': 0.5,\n",
       " 'booster': 'gbtree',\n",
       " 'colsample_bylevel': 0.7,\n",
       " 'colsample_bytree': 0.8,\n",
       " 'gamma': 0,\n",
       " 'learning_rate': 0.1,\n",
       " 'max_delta_step': 0,\n",
       " 'max_depth': 5,\n",
       " 'min_child_weight': 1,\n",
       " 'missing': None,\n",
       " 'n_estimators': 193,\n",
       " 'nthread': 1,\n",
       " 'objective': 'multi:softprob',\n",
       " 'reg_alpha': 0,\n",
       " 'reg_lambda': 1,\n",
       " 'scale_pos_weight': 1,\n",
       " 'seed': 3,\n",
       " 'silent': 1,\n",
       " 'subsample': 0.6}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgb1.get_xgb_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:1: FutureWarning: from_csv is deprecated. Please use read_csv(...) instead. Note that some of the default arguments are different, so please refer to the documentation for from_csv when changing your function calls\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEXCAYAAAC3c9OwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XmcHHW1///X6e7ZMpnJZJmsk50kEGQPCMqSq6iACojIoqigV+R+L9ftul79ql+87vd6xeXHoiKuLKICIm4XEVEQCBCWJCSE7Mkkk20y+9LT5/fHpybpTHoyk5Ce6km/nw/qQXdVddXp6km/uz5V9Slzd0RERPpKxF2AiIgUJgWEiIjkpIAQEZGcFBAiIpKTAkJERHJSQIiISE4KCJEsZvYfZvb9uOsQKQQKiGHGzEaa2Roze3vWuCozW2dmF2eNW2Bm95nZTjNrNLOlZvZFMxsdTb/SzHrMrCUaVpnZv+S59oVmtiGf6zgQuepx9y+5+z/naX1rzOzsfCw7H4bq8xpu26WYKCCGGXdvAa4Grjez2mj014BF7n4XgJm9CvgL8HfgSHevAc4B0sBxWYt71N1HuvtI4GLga2Z2wtC8EzkQZpaKuwYpQu6uYRgOwK3AbcBCYDswKWva34BvD/D6K4G/9Rn3OPD2rOfnA0uARkLgHJU17ahoXGM0z/lZ084DlgLNwEbgo0Al0A5kgJZomNzP+/ou8Nvo9Y8BswexPY4E/gTsAJYDlxxMPcDngZ9Gr5sBOHAVsB7YCVwDnAw8G73372StZzbw5+jz2Ab8DKiJpv0kWld7tK6PD2IbrwE+Ea2rE0hFzzdG72U58Noc2+JUYDOQzBr3FuDZ6PEpwCKgCdgCfKOfbboQ2NDPtFHAj4GtwFrgM0AimpYE/jvaBquBa6PtmOpnWWuAs/uZ9j5gZfS53tv7NwMY8D9AA7Ar2kav6O/zjvvf63AdYi9Aw0F+cDAaqI/+EV6VNb4S6AEWDvD6K8kKiOhLrxGYGz2fC7QCrwNKgI9H/1BLo+crgf+Inr8m+sc4L3ptPXBGVp0nRo/7/cLJquPW6MvglOgL8WfA7QO8ppLwBX5V9JoTo+1y9IHWQ+6AuBEoB14PdAB3A+OBKdEX1FnR/EdE26sMqAX+Cnwza9l7fRHubxtnzb8YmApUAPOi9zk5q76c4Qm8BLwu6/kvgE9Gjx8F3hk9Hgmc2s8y+v28COFwD1AV1bECeG807RrCF3RdtL3/l4MIiOjvalv0eZYB3wb+Gk17A/AkUEMIi6OIfiT193lrOPBBTUzDlLvvJPzyHAH8KmvSaELT4ebeEWb2teg4RKuZfSZr3lOj8S2EvYefAC9G0y4Ffuvuf3L3buC/CF9SryL8Qh0JfMXdu9z9z8B9wOXRa7uB+WZW7e473f2pA3x7v3L3x909TQiI4weY/03AGnf/obuno/X9ktBsdijq+YK7d7j7Hwlf6Le5e4O7bwQeBk4AcPeV0fbqdPetwDeAs/az3P1t417fcvf17t5OCP6y6L2UuPsad3+pn2XfRvR5mFkV4Vf1bVnb4wgzG+fuLe7+jwPZGGaWjGr/lLs3u/sawh7DO6NZLgGud/cN0d/pVw5k+VneAdzi7k+5eyfwKeA0M5sRvYcqwp6jufsyd6/Pen8v5/OWiAJimDKzKwi/3P4X+GrWpJ2EpoxJvSPc/eMejkP8mvALu9c/3L3GwzGIicDRwJeiaZMJTQe9y8gQfr1Oiaatj8b1WhtNA3gr4QtprZk9ZGanHeDb25z1uI0QRvszHXhlFHaNZtZI+HKZeIjq2ZL1uD3H85EAZjbezG43s41m1gT8FBi3n+Xubxv3Wp81fSXwIcJeTkO0rsn9LPvnwEVmVgZcBDzl7r3rei9h7+UFM3vCzN60nxpzGUfYc1ybNS7785+cXXefxwei7/ZpITTfTYl+lHyH0By5xcxuNrPqaNaX
+3lLRAExDJnZeEL76/uA9wOXmNmZAO7eSmi3v+hAlunuWwi/ut8cjdpE+OLtXacRmjo2RtOmmln238+0aBru/oS7X0BohrkbuLN3NQdS0wFYDzwUhV3vMNLd/2WI6/lytMxj3b0auILQ/NGr7/r2t41zvsbdf+7up0evc/b+cZA931LCl+u5wNsJgdE77UV3v5ywPb4K3GVmlYN/m2wj/EqfnjVu9+dPaOKpy5o29QCWna3v9qkExrLn7+xb7n4S4YfNXOBj0fj+Pm85QAqI4ek7wN3u/mC0W/1x4HvRr0Wi5+8xs09GYYKZ1QEz+1ugmY0lHMhcEo26E3ijmb3WzEqAfyccKH2EEECtwMfNrMTMFhKC5XYzKzWzd5jZqKjZpInQNALhl/dYMxt1iLZDr/uAuWb2zqieEjM72cyOGuJ6qggHoBvNbArRF1aWLcCsrOf728b7MLN5Zvaa6HPuIOy99OSaN/Jz4APAmYRjEL3LucLMaqM9lsZodL/LMbPy7IGwh3on8MXoFOvpwEcIe0y97+uDZjbFzGoIB9YHUtJnPamo/qvM7PjoPX8JeMzd10Sf7yuj7dYabY+eAT5vOVBxHwTRcGADcCHhl1VNn/EPAF/Mev5K4H7CF0Aj8DzwRWBsNP1Kwj+c3jN4Gght1OOzlvEWwsHGXcBDRAd9o2lHR+N2RfO8JRpfCvye0NTVBDwBnJ71ulsIzQSN9H8W039mPV/IAAe2o/nmEc582hot/8+EYxcHVA+5D1KnsubfQNYJAIQvxc9kbZMno+25mPCFvyFr3guAddG6PjqIbbyGvQ9qH0s4VtRMOJB/X65tmDX/NMKX+W/7jP9p9Hm3EH4QXNjP6xdG77/vcAThWNdPo+29Hvgse85iShH2cLcTzmL6MGGPw/pZz5oc6/jPaNo1hAPuve+3Lhr/WsKZSy3sOWNs5ECft4YDGyza2CIieWFm5wI3uvv0AWeWgqImJhE5pMyswszOM7NU1NT2OcIJEjLMaA9ChgUzOwP4Xa5pHs7CkgJhZiMIzWVHEo6T/Bb4oLs3xVqYHDAFhIiI5KQmJhERyWnYdQA2btw4nzFjRtxliIgMK08++eQ2d68deM49hl1AzJgxg0WLFsVdhojIsGJmaweea29qYhIRkZwUECIikpMCQkREclJAiIhITgoIERHJSQEhIiI5KSBERCSnogmIna1dPLO+ceAZRUQEKKKA+P43/y8Tvn88be3tcZciIjIsFE1AvObIWibaTnY21A88s4iIFE9AlI4K96/ftX3jAHOKiAgUUUCMGDMZgLbtm2KuRERkeCiagBg1LgRE564tMVciIjI8FE1A1NROASDTrIAQERmMogmIVEUVbZRjrVvjLkVEZFgomoAAaEzUUNKxLe4yRESGhaIKiJbUGMo7t8ddhojIsFBUAdFROoaR6R1xlyEiMiwUVUCkR9QyOrMTd4+7FBGRgldUAeEjaqmhheb2jrhLEREpeHkLCDO7xcwazOz5fqabmX3LzFaa2bNmdmK+aumVrJ5AwpydDbpYTkRkIPncg7gVOGc/088F5kTD1cANeawFgLJRkwBo2qaAEBEZSN4Cwt3/CuzviPAFwI89+AdQY2aT8lUPwIgxoT+mtp3qsE9EZCBxHoOYAqzPer4hGrcPM7vazBaZ2aKtWw/+QrfqqLuNrsbNB70MEZFiEWdAWI5xOU8vcveb3X2Buy+ora096BWOGqfuNkREBivOgNgATM16Xgfk9eBAoryKdi9lw/q1+VyNiMhhIc6AuBd4V3Q206nALnfP78EBM3ZZFTN9/cDziogUuVS+FmxmtwELgXFmtgH4HFAC4O43AvcD5wErgTbgqnzVkq25fBKV3emhWJWIyLCWt4Bw98sHmO7Av+Zr/f3pKJ9ATceSoV6tiMiwU1RXUgP0jJzMeN9Ba0d33KWIiBS0oguIZM1kKqyLhq061VVEZH+KLiDKxtQB0Lh5XcyViIgUtqILiOrx0wBo2aaA
EBHZn6ILiJqJMwDo3rEh3kJERApc0QVEeU3obsOb1GGfiMj+FF1AkCplh9WQatVBahGR/Sm+gACaSmqp6GiIuwwRkYJWlAHRXj6emvTB9worIlIMijIguisnMc6309HdE3cpIiIFqygDwqonMcZaaNixK+5SREQKVlEGxB/Xhbe9tX51zJWIiBSuogyIKyoeBaB585p4CxERKWBFGRCjL7sBgM7tunGQiEh/ijIgSseEG9l5o7rbEBHpT1EGBKkydiTGUtayMe5KREQKVnEGBLCrbCJVnfm9w6mIyHBWtAHRWTmF2p4tdPdk4i5FRKQgFW1A+KipTGI79Tta4y5FRKQgFW1AlI6dQan10LBpTdyliIgUpKINiOqJswDYtXlVzJWIiBSmog2ImsmzAejctibeQkREClTRBkTJmOmAroUQEelP0QYEpSPYTjVtDeqPSUQkl+INCKCtfDJ1iW1xlyEiUpDyGhBmdo6ZLTezlWb2yRzTp5vZA2b2rJn9xczq8llPX53pNBN7NtPc0T2UqxURGRbyFhBmlgS+C5wLzAcuN7P5fWb7L+DH7n4scB3w5XzVk7PG2a+hzraxdmvTUK5WRGRYyOcexCnASndf5e5dwO3ABX3mmQ88ED1+MMf0vKqYOJcS62HLuheHcrUiIsNCPgNiCrA+6/mGaFy2Z4C3Ro/fAlSZ2dg81rSXsVOPAqC5fsVQrVJEZNjIZ0BYjnHe5/lHgbPM7GngLGAjkN5nQWZXm9kiM1u0devWQ1Zg2YQjAOjZ+tIhW6aIyOEinwGxAZia9bwO2JQ9g7tvcveL3P0E4NPRuH1uFO3uN7v7AndfUFtbe+gqHDmBDiuntEmnuoqI9JXPgHgCmGNmM82sFLgMuDd7BjMbZ2a9NXwKuCWP9ezLjB1ldYxqXz/wvCIiRSZvAeHuaeBa4A/AMuBOd19iZteZ2fnRbAuB5Wa2ApgAfDFf9fSno2o6UzL1NLZ1DfWqRUQKWiqfC3f3+4H7+4z7bNbju4C78lnDQGzsbOoa/sKShiZOmDEuzlJERApKUV9JDTBi4lxKrYev3/nAwDOLiBSRog+IsdOOBOCCaR0xVyIiUliKPiBS48Kprt3L/xRzJSIihSWvxyCGhaqJtCUqqUjsc/mFiEhRK/o9CMzYVTmLSZ1raetSSIiI9FJAAD1j5zInsZEXt7TEXYqISMFQQAAVk+dTa7tYvX5D3KWIiBQMBQRQM/1YAJrWPx9zJSIihUMBASTHzwOgZ8sLMVciIlI4FBAAo6bSThlsVUCIiPRSQAAkEqy1OmazgV1tuv2oiAgoIHarmX4MRyQ2sqR+n97GRUSKkgIiUrXjeSbbDlau1ZlMIiKggNit8s1fBaBpzeKYKxERKQwKiF4TjwEg0aBTXUVEQAGxR9UEWkvGUNu6go7unrirERGJnQIiS/uYozjS1vLC5ua4SxERiZ0CIkvplOOYaxv46G1PxF2KiEjsFBBZqqYfT5mlGde1Pu5SRERip4DIYpNCn0wnlOpUVxERBUS2sXNIk2RM0zJaO3VvCBEpbgqIbMkUrWNfwTGJ1Ty7QVdUi0hxU0D0UTb9ZI6xVSxety3uUkREYqWA6KN8+slUWidbXno27lJERGI1YECY2WwzK4seLzSzD5hZTf5Li8mUkwBI1D+Fu8dcjIhIfAazB/FLoMfMjgB+AMwEfp7XquI0ZjYtjGBW53I27GyPuxoRkdgMJiAy7p4G3gJ8090/DEwazMLN7BwzW25mK83skzmmTzOzB83saTN71szOO7Dy8yCRYGXJHI5LvMTjq3fEXY2ISGwGExDdZnY58G7gvmhcyUAvMrMk8F3gXGA+cLmZze8z22eAO939BOAy4P8bbOH5dNwpr+XIxHoWrdwUdykiIrEZTEBcBZwGfNHdV5vZTOCng3jdKcBKd1/l7l3A7cAFfeZxoDp6PAooiG9kW/E7SuihaZW63BCR4jVgQLj7Unf/gLvfZmajgSp3/8oglj0F
yO6zYkM0LtvngSvMbANwP/BvuRZkZleb2SIzW7R169ZBrPplujLsKE1reZZNjToOISLFaTBnMf3FzKrNbAzwDPBDM/vGIJZtOcb1PS3ocuBWd68DzgN+Ymb71OTuN7v7AndfUFtbO4hVv0yV4+isOYIFieU8tnp7/tcnIlKABtPENMrdm4CLgB+6+0nA2YN43QZgatbzOvZtQnovcCeAuz8KlAPjBrHsvCuZ9WoWJFbw5fuWxl2KiEgsBhMQKTObBFzCnoPUg/EEMMfMZppZKeEg9L195lkHvBbAzI4iBMQQtCENLDHtNEZZK+M7V+t6CBEpSoMJiOuAPwAvufsTZjYLeHGgF0Wnxl4bvXYZ4WylJWZ2nZmdH83278D7zOwZ4DbgSi+Ub+NppwJwvC/jpa0tMRcjIjL0UgPN4O6/AH6R9XwV8NbBLNzd7yccfM4e99msx0uBVw+22CE1egY9lRM5pekF/rJ8K0eMr4q7IhGRITWYg9R1ZvZrM2swsy1m9kszqxuK4mJlRnLm6ZyeWsZDyxvirkZEZMgNponph4RjB5MJp6n+Jhp3+Ju1kDHeSMOqxVx8wyNxVyMiMqQGExC17v5Dd09Hw63AEJxrWgBmLQTg1fY8TR3dsZYiIjLUBhMQ28zsCjNLRsMVQHFcHFAzFR9zBKcnnmdna1fc1YiIDKnBBMR7CKe4bgbqgYsJ3W8UBZu9kFenlkFPN+meTNzliIgMmcF0tbHO3c9391p3H+/uFxIumisOsxZS5h3M6FjG42vUu6uIFI+DvaPcRw5pFYVs5pl4IsVrkk/z4TsWx12NiMiQOdiAyNXP0uGpfBQ2/VW8PvU0O1q76MkUxnV8IiL5drABUVzfknPPZTYbmJjZrM77RKRo9BsQZtZsZk05hmbCNRHFY945ALyx9BnufnpjzMWIiAyNfgPC3avcvTrHUOXuA3bRcVgZMwvGzeOikc/xu+c209HdE3dFIiJ5d7BNTMXnyPOY07aYZOdO3vSth+OuRkQk7xQQgzX/Qsx7eGPqSba2dMZdjYhI3ikgBmvScTB6JlePe5Zd7WnWbm+NuyIRkbxSQAyWGRz9FqbteoJxiWZue3z9wK8RERnGBtPdd66zmdZHXYDPGooiC8bRoZnpgrIn+f7Dq+hM62C1iBy+BrMH8Q3gY4SuvuuAjwLfA24HbslfaQVo4rEwbi4XJR4mnXHuWdz3FtsiIoePwQTEOe5+k7s3u3uTu98MnOfudwCj81xfYTGD49/O0T3LOHt8M9/76yrdr1pEDluDCYiMmV1iZolouCRrWvF9Ox57GViCj01YxIsNLZzzzb/GXZGISF4MJiDeAbwTaIiGdwJXmFkFcG0eaytM1ZPgiLOZu/m3lCdhU2OH9iJE5LA0mO6+V7n7m919XDS82d1Xunu7u/9tKIosOCe8E2vexFtGPk9zZ5q/r1T/TCJy+BnMWUx10RlLDWa2xcx+aWZ1Q1FcwZp3HlRN5otTHqU0meCanz6pvQgROewMponph8C9hA76pgC/icYVr2QKFlxFYtWDnFy9g5bONA8sa4i7KhGRQ2owAVHr7j9093Q03ArU5rmuwnfiuyGR4nL7I+UlCb70u2V065akInIYGUxAbDOzK8wsGQ1XAGp0r5oAr3grb0r/LzdeNItVW1s5+78firsqEZFDZjAB8R7gEmAzUA9cDFyVz6KGjVd/ELpaOKvpHqrLU2xobKehuSPuqkREDonBnMW0zt3Pd/dadx/v7hcCFw1m4WZ2jpktN7OVZvbJHNP/x8wWR8MKM2s8iPcQnwlHw5w3YI/dxNyxSTLunPPNh7n0pkfjrkxE5GU72M76PjLQDGaWBL4LnAvMBy43s/nZ87j7h939eHc/Hvg28KuDrCc+TfXQto27TlzCR86ey47WLna2dsVdlYjIy3awAWGDmOcUYGV0HUUXoe+mC/Yz/+XAbQdZT3z+5WE44nXw8Dd4/yljGVGaZNW2VrbpnhEiMswdbEAM
5qT/KUB2n9gbonH7MLPpwEzgz/1Mv9rMFpnZoq1btx5orfl39uegYxel/7ie2bWVpDPO2d94SNdGiMiw1m9A9NPNd5OZNROuiRhIrr2M/r4xLwPucvec/We7+83uvsDdF9TWFuAZthOPgcpx8Mi3+M27ZzN9zAga27o582sP6niEiAxb/QaEu1e5e3WOocrdU4NY9gZgatbzOqC//rEvYzg2L2X75/8FS8JfvsyE6jLGVpayfmc7jW06HiEiw1M+7yj3BDDHzGaaWSkhBO7tO5OZzSN0Gz68f2qPngEnvxee/il3vnUsD3/inxhRmmT5lhbe+K2H465OROSA5S0g3D1N6O31D8Ay4E53X2Jm15nZ+VmzXg7c7odDg/2ZHwMMfvB6RpQkmTehCjNYvrmZVVtb4q5OROSA2HD7Xl6wYIEvWrQo7jL69/j34P6Pwlt/AMdczAXf+RtLNjWRSBhHT6rm1//66rgrFJEiZGZPuvuCA3lNPpuYitOC90DpSLj7GmjbQXlJkiMnVpHuybBscxMXfvdvOnAtIsOCAuJQSyThqvvBHe7/GHe8/zTu+8AZzJtYRXePs2RTE21d6birFBEZkAIiHyYdB2d9Ap6/C57/JQC/++CZ3PdvpwPw3MYmzrtetyoVkcKmgMiX0z8CdSfDvR+ErSsAOGpSNUdPqsYMltU3s/DrD3LJjY/EXKiISG4KiHxJpuBtP4JUGdxxBXQ2A3D3tadzwtQaqitKWLO9jafWNXLxDQoJESk8Coh8GjUFRk6EbcvhnmvDcQngV//n1Tz9f19H3egK0hnn+Y27OO/6v+rgtYgUFAVEvv2fv8PrroOld8Mj39o9OpEwptRUUFGaJAMsrW9mzfZWLr7hEQWFiBQEBcRQeNUHYP6F8KfPwdJ7do++4/2nceyUURw7ZRQTqsvY0tTJcxt38cyGRh2bEJHY6UK5odLdDj++ADYthit+CTPP2GvypTc9SnNHN6u3tdHe3UNlWZJpo0dQXVHCHe8/LaaiReRwoQvlCllJBVx2G4yZCT+/BFbv3T/THe8/jfs/eCbHTKmmPJWgK51h2eZmFq3dyZu/rb6cRGToaQ9iqLU0wPXHQboD3nXvPnsSEPYmMhlnc1MH63e2AzCqooSudA/HTBnFnde8aqirFpFh7mD2IBQQcWhpgB+9GRrXwWU/h9n/lHO2S296lCWbdlFTUcqW5g66e5zykgQTq8sZN7KMZCLcckNNUCIyEAXEcNLSEI5JbHsRLrwBjn3bfmd/242P8NzGXSTNaO3qIWEwtrKUlq4ejplcjZnCQkT6p4AYbtob4ZvHQGdTOBX2VR8A6/9235fe9CjuTktnmq3NnWxv7SLjUJZKMG5kGTvbuniFwkJEclBADEfpTvj1+2HJr+Gkq+Dcr4arr/ej9zqJnozz/KZdlCYTNHWEDgBLkkZNRQmjKkqorijhxYYW5k+qVliIFDkFxHCVycAD/w/+/k2YfELoomP09EG9tDcsutIZltY3UVmWYld7Nz2Z8LkmDCZWl1NVnmJjYztHTx61+7UKDZHioYAY7pbdB3e+Cwy45Kdw5HkH9PLesAjNUD3sau+mflc7mayPeERpkqryECJHTaymNLX3mc4KDZHDkwLicLBjFdx4OnS1wonvhjd8CcpGHtSiLr3pUZbWNzFvQhUtnWlWbW2hvCRJS2d6d2iUJI3KshSVpSl2tHZyZBQaS+ub1DQlchhRQBwu0p3w4Bfh79dDqhzedQ9MO/VlL7Z3DyPj4cZF40aW0dqZprUrTUd3Zvd8JUkj4zChuoyKkuTu4YUtzcyfVL17PoWHyPChgDjcrH0Efnwh9HTC6R+Ghf8BqdJDsui+HQL2ZJwl9U2MH1lGa1eaHdEZUtnMoLq8hIqSJDvbuphdW0lFSZIV0YHwXgoOkcKjgDgcdTbD7z8FT/8EUhVw+c9h9mvysqrs0Fha38SRE6roSPfQ3tVDe3eGLU0dlKYStHf30PfPZkRpktJU
grJUgtJkgq0tncwaN5KyVIIXG5r3OjjeS0EiMnQUEIezF/8UbjyU7oAj3wRv+CKMnjFkq+89njF/UjUeNVHVja6gvbuH+l0dVJal6Epn6Exndp9Bla0kaZQmE5SmEpQkEzS2dzOlpoLSpLF+ZztHTqwilbDd13D0rquXwkTk5VFAHO66O+DR74TjE+7wymvgzI9C5bjYSsoOjl7pjLOsvompoyvoTGfY1Nge+pLqcbp7MnSlM6RzhIgBJckEpSmjozvD2MpSSlIJtjZ3Mm3MCFIJY+2ONuZNqCKVNBIKE5FBU0AUi6ZN8Jcvw1M/BkvCWZ+A0/71oM92yqdcAQKwZNMujqgdSVdPhlXbWplQVU5XT2Z3gLR0pjGznHsjvRIGqWSCdE+GkWUpSpIJmjq6GV9VRiqRIJU0UomwhzJn/EhSSWP55j3NXQoWKSYKiGKzdQX88Bxo2w6VtfCqf4OTroTyfdv7C1muO+j1fnlnMs7SzU3MHFtJOuOs2d7KxOpy0j1OOhP2RHa2dVNekiDd43SmMznWsLdUwkglja50hqryEpIGyWSCkoSxvbWLutEVu4Nldm0lCQt7K4kE/QaMwkYKXcEFhJmdA1wPJIHvu/tXcsxzCfB5wIFn3P3t+1umAiKH9U/An78Aqx8KexSvfH9ofhrk1djDyf7CpPfxUROrSGd8d4is3tbKpFEVpDMZNu/qYNSIUtI9GZo60pSnEvS405NxunsG92/BCLeMzWSc0lSCZMLo6O5hZFmKRMJIRoHS2N4Vet01Y2tLJ5NGlZMwI5kwEgbrdrYze1wIoJe2tTBvQhVJM5ZtbtrnoL6uS5GXq6ACwsySwArgdcAG4AngcndfmjXPHOBO4DXuvtPMxrt7w/6Wq4DYj02L4dHvwnO/AByOfguccjVMO22/nQAWi77NXX1/9bs7S+ubmFU7knRPCJa60SPoyTgZD8Pmpk7GVpaScWdHaxfV5SVk3GnqSFNRkgzzZZweD4FjhF8+ByphZIWJ0ZnuobIsRdKMls40NSNKdu/ZJBOwtaVrdwDV72qnrqYCM8MsBNq6HW3MGFdJAmP19lZm11ZiGC9tbWHOhJEkzDDAzHghK6ByNQ/2tw21F1XYCi0gTgM+7+5viJ5/CsCKcfxdAAAT9UlEQVTdv5w1z9eAFe7+/cEuVwExCLs2wmM3hrDwHigZEXqLPfZSKK8e+PUyoIHCJntcxsNB+znjq3YHTU/GWb2tlamjR5BxZ/3OdiZWl5NxZ0tTB2Mqy/aat6kjzYjSJJmM097dQyqZILM7uPLzHhMGGQ9NcokobHrDqqI0ScKMtq4eqspSmEFzRwguMyNB+E2yvTXsRRmwraWT8VXlIbQMwNjS1MGkUWFcfWMHU0ZXYMCGxnamjh6xO+DAWLejlRljKzGDNdvbmDWuMkwxdofdEePDcbiVDSH4wpRgRUMz8yZUYcDyLc3Mm1iNAS9sbuKoidWwe12wbHPzoD7b/T0eaN6DWRYcfPAWWkBcDJzj7v8cPX8n8Ep3vzZrnrsJexmvJjRDfd7df59jWVcDVwNMmzbtpLVr1+al5sNOVys8/0t44gdQvxgsASe+Cxa8FyYdG3d1MoC+16UMtOczd0IVmYzzYkMLs2orcQ97L+7O6u1tTB8zAndn7Y426kaHxxsb25k0qhz3EAbuzpbmTsZVluKEL/XRI8IeU5gnhNXIsiQZh9bONOUlSdyhI91DSSKB47vX3ZNxzNjnupnhIGFhjyqTcVLJKKaiPa2ungxlqQRGCMzykiQGtHf3MKI0tddy2rvSu8e1ZT3uTaO2zjQjylIY0NoVmioNaO5MU12e2mu9tVVl/P5DZx7U+zmYgEgNPMtBy9Wm0ffPJAXMARYCdcDDZvYKd2/c60XuNwM3Q9iDOPSlHqZKK0MgnPgu2Pgk/PyycObTk7dC3Slw8nth/oVQUh53pZJD3E00l970KC2daWZGv9R7La1v4siJB/7recmmXdF1NHuC64Ut
zcwdX4XjvLgl7AG4w8qtLbv3EML82SEH63a0MjV63LusjY3tTKmpwB02NbYzqaYC8N7/qN/VwcRR5eCwuamDCdWhW/3NTZ1MqCoLX07RurY2dzKmsgwnNCXWjCgF92hdsKu9mxGlKdzDqdulyQQOWDpDou83X7T3lf3Y+0yHPQGdzmRCHR5Ouuh9/wA1PSUH8UkevHwGxAZgatbzOmBTjnn+4e7dwGozW04IjCfyWFdxmnISfOxFaNsBz9wGD3wh3Ifi95+C4y6HYy+BScfpWIXsFndAxeXSmx6lqSPN9LEjAGjr6tkdVr2W1jcxJ2rOWlrfxLyJVbsfH5UjJI/KPoniZTYxDaV8NjGlCM1HrwU2Er703+7uS7LmOYdw4PrdZjYOeBo43t2397dcHYM4RNzDWU9P/ACW/w4y3VBSAWf8OxzztiG9SltE8q+gmpjcPW1m1wJ/IBxfuMXdl5jZdcAid783mvZ6M1sK9AAf2184yCFkBrMWhqFtByy9G/70Ofjzf4ahrBpe+1k48o1QPTneWkUkFrpQTva2cy08fxf89b+guy2MK62E066FuefApOMhkdj/MkSk4BTUWUz5ooAYIu6w9QVY8XtY/ntY/48wPlkKx10WwmLWwhAeIlLwFBCSP63bYeWfwvGKZb8J11dgMOd1MPcNITBG1cVdpYj0QwEhQyPdBesegRV/gEW3hC7IASYcA/POCWEx+QRIJOOtU0R2U0DI0HOHbS/Cit+F4xadTXumVYwJPc3OPANqj9KxC5EYKSAkfm07YOUD8MdPQ8euPXsXiVS40dHMM2DmWTD2CF1zITKECuo0VylSI8bAsW8LA0DjOlj9MKx5OHT7sfTuMD5ZCvMvgJlnwowzwnUXCgyRgqI9CBk67rBjVQiLB74Q9jAy3WFasgyOuTiExcwzdMBb5BDTHoQUNjMYOzsMJ10ZnUq7PATG6ofg2Ttg8c/CvMkymHcu1J0chknHqc8okSGmgJD4mMH4I8Nwyvsgk4Etz8Pav8OGJ8LQ2ySFhTOjpp4ShcYCqJmuZimRPFITkxS25i17wmLRLXufJZUoCddh1C0IvdNOPqEg78stUgh0FpMc/nrS0LA0Co1FsOTXkG7fM72kMvRM29s0NfYInV4rggJCilXbjnC/iw1PwGM3QWdzdKU34fTa0io49ZqwpzHpBKgcG2+9IjFQQIhAOJax/UVY/3gIjed+safjQQgHwEsroXQknPvVcHe96ik6niGHNQWESH86m2HjU1D/DGx+Fpbdt3fTVCIVAqO0El7/BZh4HIyZpeYpOWzoNFeR/pRVwayzwtCrqxW2LAmhUf8MLPkVNG2Eu96T9brqcMe9SceFPY3aIyE5tLd9FImL9iBEsqW7Qjfnm58NN07qaoWuFvDMnnlKK+GoC2DC0XuGkePjq1lkENTEJJIPmZ5wBXj9M/Cnz4bQ6G6Dnq6smQzKq+H4d8D4+TBhftjb0P0ypECoiUkkHxJJGDcnDMdcvGd867bQRNWwFB7+H+huhcdu3HtvI1UOR5wdQmP8UWFvY8xsSOqfnhQ+7UGIHEqZHti5JoTGlqXQsARW/HHvA+IYlIyA0hHhVq4Tjg4BUj1ZZ1JJ3qiJSaRQdXfAtuVRaCyFp34CHY1A1r+/RDJcEZ69tzF+PlTUxFa2HD7UxCRSqErKozOhjgvPX/+F8P/2ndCwbE9T1XN3wbpH935tsjRcIX7iO/eExri56rxQ8k57ECKFxj2cbtu7t9GwNFy30d3GXnscyTIoqYBXvDVcszFmVugpt2a6wkP2oT0IkcOBWbgfxqg6mPv6PeN7usPZVFuWhFNxd6wKw1M/gkx672Uky2DaqXsHx5hZ4cZMJRVD+nZk+FJAiAwXyRKonReGvtp2wI7VsOOlcDOmdHu4cnzt33OHx9RT9g6OMbNg9Mxw4FwkooAQORyMGBOGupNCb7bZ2ndGexurw/8fvxnqF8O6f+y5o1+vZGk4NffoC/cEx5jZMGam
rukoQjoGIVLM2hthZxQc26PwaNsOlug/PFIVcOr7o+CYFcKjrCqe+mXQCu40VzM7B7geSALfd/ev9Jl+JfB1YGM06jvu/v39LVMBITJEOppCeGx/ac8eyLJ7Q8eH9PneqByf1Vw1Mys8ZoUrzCV2BRUQZpYEVgCvAzYATwCXu/vSrHmuBBa4+7WDXa4CQqQAdLbsHR6P3RiuLE8k+3RBQrjzX0m053HK1VGARMc/ykfFU38RKrSzmE4BVrr7KgAzux24AFi631eJSOErGwkTjwkDwBkf2TOtq3XP8Y4dq8KB8yX3QOtWePA/+yzIQjfr884NwTF65p69kMpaXVkes3wGxBRgfdbzDcArc8z3VjM7k7C38WF3X993BjO7GrgaYNq0aXkoVUQOmdJKmPiKMPQ6/9vh/11toSuS7PDYsQqW3gM9nXsvx5KQKgt7Hie8IytAZkJ1nfqzGgL53MK5or9ve9ZvgNvcvdPMrgF+BLxmnxe53wzcDKGJ6VAXKiJDpHRE6Ol2wvx9p6U7oXFd2PvYuTr8/9nbwwWCj3ybfb4+UuUw4/Q9odF7qu7o6brW4xDJZ0BsAKZmPa8DNmXP4O7bs55+D/hqHusRkUKWKtvTa26vc6PzWjIZaN60p+nqoa9BuiN0S/LSg3vuQb6bhTOr5p4TXXQ4Jex19D4ur1Hz1SDkMyCeAOaY2UzCWUqXAW/PnsHMJrl7ffT0fGBZHusRkeEqkdhzdfnMM+Ckd++Z5h4uFOzd69i5Gh7/XgiQpfdEB81zNDyUVMKshWGPo2Y61Ezb87hs5BC9scKWt4Bw97SZXQv8gXCa6y3uvsTMrgMWufu9wAfM7HwgDewArsxXPSJymDKDyrFhqItO0jnr43umZ3qgpSH0b7VrPezaAI9+NwTIqgejPq72WWg4lnLE2VFoTIOaGeHxqKlF09eVLpQTkeLmHi4O3LkWGteE/z92UwiQ3mEfFvYy5p67955HzbSwl1OA9y0vqOsg8kUBISJDKpOBls1RgKwNB9If/144qJ7u2PfsKwj9XaXKYN550d5H1hBTgBTadRAiIsNfIhHu9lc9GaafFsZlN2H1dIfmq51rQxNW47o9w5JfDyJApu4dINV1kCodmvc2AAWEiMjLkSwJ3aiPnpF7em+AZAfHgAES9Xs179y9w2PCK6ByXD7fzV4UECIi+TSoANm0d3AsuiUcF1ly994BMmY2fOCpoagaUECIiMQrWRIOco+evmfcP31qz+PsABkzc0hLU0CIiBSyXAEyRBJDvkYRERkWFBAiIpKTAkJERHJSQIiISE4KCBERyUkBISIiOSkgREQkJwWEiIjkNOx6czWzrcDag3z5OGDbISznUFN9B6+Qa4PCrq+QawPV93Jk1zbd3WsP5MXDLiBeDjNbdKDd3Q4l1XfwCrk2KOz6Crk2UH0vx8utTU1MIiKSkwJCRERyKraAuDnuAgag+g5eIdcGhV1fIdcGqu/leFm1FdUxCBERGbxi24MQEZFBUkCIiEhORRMQZnaOmS03s5Vm9smYa5lqZg+a2TIzW2JmH4zGf97MNprZ4mg4L8Ya15jZc1Edi6JxY8zsT2b2YvT/0THVNi9rGy02syYz+1Cc28/MbjGzBjN7Pmtczu1lwbeiv8VnzezEGGr7upm9EK3/12ZWE42fYWbtWdvwxnzWtp/6+v0szexT0bZbbmZviKG2O7LqWmNmi6PxcWy7/r5LDs3fnrsf9gOQBF4CZgGlwDPA/BjrmQScGD2uAlYA84HPAx+Ne3tFda0BxvUZ9zXgk9HjTwJfLYA6k8BmYHqc2w84EzgReH6g7QWcB/wOMOBU4LEYans9kIoefzWrthnZ88W47XJ+ltG/k2eAMmBm9O86OZS19Zn+38BnY9x2/X2XHJK/vWLZgzgFWOnuq9y9C7gduCCuYty93t2fih43A8uAKXHVcwAuAH4UPf4RcGGMtfR6LfCSux/s1fWHhLv/FdjRZ3R/2+sC4Mce/AOo
MbNJQ1mbu//R3dPR038Adfla/0D62Xb9uQC43d073X01sJLw73vIazMzAy4BbsvX+geyn++SQ/K3VywBMQVYn/V8AwXyhWxmM4ATgMeiUddGu363xNWEE3Hgj2b2pJldHY2b4O71EP4wgfGxVbfHZez9D7RQth/0v70K7e/xPYRflb1mmtnTZvaQmZ0RV1Hk/iwLadudAWxx9xezxsW27fp8lxySv71iCQjLMS7283vNbCTwS+BD7t4E3ADMBo4H6gm7r3F5tbufCJwL/KuZnRljLTmZWSlwPvCLaFQhbb/9KZi/RzP7NJAGfhaNqgemufsJwEeAn5tZdQyl9fdZFsy2Ay5n7x8nsW27HN8l/c6aY1y/269YAmIDMDXreR2wKaZaADCzEsIH+jN3/xWAu29x9x53zwDfI4+7zgNx903R/xuAX0e1bOndHY3+3xBXfZFzgafcfQsU1vaL9Le9CuLv0czeDbwJeIdHDdRR08326PGThDb+uUNd234+y0LZdingIuCO3nFxbbtc3yUcor+9YgmIJ4A5ZjYz+tV5GXBvXMVEbZc/AJa5+zeyxme3Bb4FeL7va4eCmVWaWVXvY8IBzecJ2+zd0WzvBu6Jo74se/2CK5Ttl6W/7XUv8K7ojJJTgV29zQFDxczOAT4BnO/ubVnja80sGT2eBcwBVg1lbdG6+/ss7wUuM7MyM5sZ1ff4UNcHnA284O4bekfEse36+y7hUP3tDeUR9zgHwtH7FYRU/3TMtZxO2K17FlgcDecBPwGei8bfC0yKqb5ZhDNFngGW9G4vYCzwAPBi9P8xMW7DEcB2YFTWuNi2HyGo6oFuwq+09/a3vQi7+d+N/hafAxbEUNtKQlt079/fjdG8b40+82eAp4A3x7Tt+v0sgU9H2245cO5Q1xaNvxW4ps+8cWy7/r5LDsnfnrraEBGRnIqliUlERA6QAkJERHJSQIiISE4KCBERyUkBISIiOSkgREQkJwWEyCCY2fF9upw+3w5Rt/EWuiofcSiWJXIo6ToIkUEwsysJFxVdm4dlr4mWve0AXpN0955DXYtINu1ByGElumnLMjP7XnQDlT+aWUU/8842s99HPdY+bGZHRuPfZmbPm9kzZvbXqHuW64BLoxvBXGpmV5rZd6L5bzWzG6Ibt6wys7OiHkiXmdmtWeu7wcwWRXX9v2jcB4DJwINm9mA07nILN2t63sy+mvX6FjO7zsweA04zs6+Y2dKox9P/ys8WlaKW70vBNWgYyoFw05Y0cHz0/E7gin7mfQCYEz1+JfDn6PFzwJTocU30/yuB72S9dvdzQrcLtxO6MbgAaAKOIfwAezKrlt7uDpLAX4Bjo+driG7ORAiLdUAtkAL+DFwYTXPgkt5lEbqasOw6NWg4lIP2IORwtNrdF0ePnySExl6i7pFfBfzCwi0jbyLcnQvg78CtZvY+wpf5YPzG3Z0QLlvc/TkPPZEuyVr/JWb2FPA0cDThzl99nQz8xd23erihz88IdzUD6CH02gkhhDqA75vZRUDbPksSeZlScRcgkgedWY97gFxNTAmg0d2P7zvB3a8xs1cCbwQWm9k+8+xnnZk+688Aqajn0Y8CJ7v7zqjpqTzHcnL119+rw6PjDu6eNrNTCHfUuwy4FnjNIOoUGTTtQUhR8nBTldVm9jbYfTP346LHs939MXf/LLCN0H9+M+GevwerGmgFdpnZBMK9LHplL/sx4CwzGxd1HX058FDfhUV7QKPc/X7gQ4Qb64gcUtqDkGL2DuAGM/sMUEI4jvAM8HUzm0P4Nf9ANG4d8MmoOerLB7oid3/GzJ4mNDmtIjRj9boZ+J2Z1bv7P5nZp4AHo/Xf7+657rtRBdxjZuXRfB8+0JpEBqLTXEVEJCc1MYmISE5qYpLDnpl9F3h1n9HXu/sP46hHZLhQE5OIiOSkJiYREclJASEiIjkpIEREJCcFhIiI5PT/Ax8R4yHET2kEAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x14c877d5710>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "cvresult = pd.read_csv('1_nestimators.csv', index_col=0)\n",
    "\n",
    "# plot\n",
    "test_means = cvresult['test-mlogloss-mean']\n",
    "test_stds = cvresult['test-mlogloss-std'] \n",
    "        \n",
    "train_means = cvresult['train-mlogloss-mean']\n",
    "train_stds = cvresult['train-mlogloss-std'] \n",
    "\n",
    "x_axis = range(0, cvresult.shape[0])\n",
    "        \n",
    "pyplot.errorbar(x_axis, test_means, yerr=test_stds ,label='Test')\n",
    "pyplot.errorbar(x_axis, train_means, yerr=train_stds ,label='Train')\n",
    "pyplot.title(\"XGBoost n_estimators vs Log Loss\")\n",
    "pyplot.xlabel( 'n_estimators' )\n",
    "pyplot.ylabel( 'Log Loss' )\n",
    "pyplot.savefig( 'n_estimators4_1.png' )\n",
    "\n",
    "pyplot.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test-mlogloss-mean</th>\n",
       "      <th>test-mlogloss-std</th>\n",
       "      <th>train-mlogloss-mean</th>\n",
       "      <th>train-mlogloss-std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>n_estimators</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.041957</td>\n",
       "      <td>0.000657</td>\n",
       "      <td>1.041060</td>\n",
       "      <td>0.000874</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.992393</td>\n",
       "      <td>0.000762</td>\n",
       "      <td>0.990396</td>\n",
       "      <td>0.001215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.950408</td>\n",
       "      <td>0.001488</td>\n",
       "      <td>0.947473</td>\n",
       "      <td>0.001137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.913990</td>\n",
       "      <td>0.001579</td>\n",
       "      <td>0.910169</td>\n",
       "      <td>0.001065</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.882584</td>\n",
       "      <td>0.001963</td>\n",
       "      <td>0.877826</td>\n",
       "      <td>0.001323</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              test-mlogloss-mean  test-mlogloss-std  train-mlogloss-mean  \\\n",
       "n_estimators                                                               \n",
       "0                       1.041957           0.000657             1.041060   \n",
       "1                       0.992393           0.000762             0.990396   \n",
       "2                       0.950408           0.001488             0.947473   \n",
       "3                       0.913990           0.001579             0.910169   \n",
       "4                       0.882584           0.001963             0.877826   \n",
       "\n",
       "              train-mlogloss-std  \n",
       "n_estimators                      \n",
       "0                       0.000874  \n",
       "1                       0.001215  \n",
       "2                       0.001137  \n",
       "3                       0.001065  \n",
       "4                       0.001323  "
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cvresult.head()"
   ]
  },
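  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how a tree count could be read off a table like the one above: pick the boosting round with the lowest mean test log loss. The names `cv_head`, `best_round` and `best_n_estimators` are illustrative and are not used elsewhere in this notebook. The sketch assumes the row labels are 0-based boosting rounds, and it only uses the five rows shown above, so its result says nothing about the full CSV."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# First five rows copied from the table above (illustration only).\n",
    "cv_head = pd.DataFrame({\n",
    "    'test-mlogloss-mean': [1.041957, 0.992393, 0.950408, 0.913990, 0.882584],\n",
    "    'test-mlogloss-std': [0.000657, 0.000762, 0.001488, 0.001579, 0.001963],\n",
    "})\n",
    "\n",
    "# Row label of the lowest mean test log loss.\n",
    "best_round = cv_head['test-mlogloss-mean'].idxmin()\n",
    "\n",
    "# Assumption: labels are 0-based rounds, so the tree count is label + 1.\n",
    "best_n_estimators = int(best_round) + 1\n",
    "print(best_n_estimators)"
   ]
  },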
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
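    "## Step 2: Tune the tree parameters: max_depth & min_child_weight, then subsample and colsample_bytree\n",
    "(Coarse search with a step size of 2; the next step is a fine search around the best coarse value, with the step size reduced to 1.)"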
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "import xgboost as xgb\n",
    "\n",
    "import pandas as pd \n",
    "import numpy as np\n",
    "\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "from sklearn.metrics import log_loss\n",
    "\n",
    "from matplotlib import pyplot\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the best n_estimators from the first round (193); keep the other parameters at their default values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
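    "When cross-validation is used to evaluate the model, the scoring parameter defines the evaluation metric. scikit-learn treats every score as higher-is-better, so a loss function used as a metric is negated, e.g. neg_log_loss and neg_mean_squared_error. See the scikit-learn documentation: http://scikit-learn.org/stable/modules/model_evaluation.html#log-loss"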
   ]
  },
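  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sign convention described above can be checked with a small sketch. It runs on a synthetic three-class dataset with a logistic regression model, both chosen only for illustration and unrelated to the rental data or to XGBoost. The `neg_log_loss` scorer should return the log loss with its sign flipped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import get_scorer, log_loss\n",
    "\n",
    "# Synthetic three-class problem (illustration only).\n",
    "X_demo, y_demo = make_classification(n_samples=300, n_features=8, n_informative=5,\n",
    "                                     n_classes=3, random_state=0)\n",
    "demo_model = LogisticRegression(max_iter=1000).fit(X_demo, y_demo)\n",
    "\n",
    "# Plain log loss: lower is better.\n",
    "loss = log_loss(y_demo, demo_model.predict_proba(X_demo))\n",
    "\n",
    "# The scorer used by GridSearchCV: higher is better, so the loss is negated.\n",
    "score = get_scorer('neg_log_loss')(demo_model, X_demo, y_demo)\n",
    "\n",
    "print(loss, score)\n",
    "assert np.isclose(score, -loss)"
   ]
  },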
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
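    "### Tuning two parameters jointly is slow, so tune max_depth and min_child_weight one after the other\n",
    "If the best value falls on the edge of the grid, try a few more values on either side of it for a finer search."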
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tune max_depth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': range(4, 10, 2)}"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# max_depth: suggested range 3-10; min_child_weight = 1/sqrt(ratio_rare_event) = 5.5\n",
    "max_depth = range(4,10,2)\n",
    "#min_child_weight = range(1,6,2)\n",
    "#param_test2 = dict(max_depth=max_depth, min_child_weight=min_child_weight)\n",
    "param_test2_1 = dict(max_depth=max_depth)\n",
    "param_test2_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.60029, std: 0.00189, params: {'max_depth': 4},\n",
       "  mean: -0.59955, std: 0.00316, params: {'max_depth': 6},\n",
       "  mean: -0.61430, std: 0.00179, params: {'max_depth': 8}],\n",
       " {'max_depth': 6},\n",
       " -0.5995472795668911)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgb2_1 = XGBClassifier(\n",
    "        learning_rate=0.1,\n",
    "        n_estimators=193,  # best n_estimators from the first round\n",
    "        max_depth=5,\n",
    "        min_child_weight=1,\n",
    "        gamma=0,\n",
    "        subsample=0.6,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective='multi:softprob',\n",
    "        seed=3\n",
    "        )\n",
    "\n",
    "gsearch2_1 = GridSearchCV(xgb2_1, param_grid = param_test2_1, scoring='neg_log_loss',n_jobs=-1, cv=3)\n",
    "gsearch2_1.fit(train_x , train_y)\n",
    "\n",
    "gsearch2_1.grid_scores_, gsearch2_1.best_params_,     gsearch2_1.best_score_"
   ]
  },
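  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The deprecation warning above says that `grid_scores_` is unavailable from scikit-learn 0.20 onward, with `cv_results_` as its replacement. The sketch below shows one way to read the same mean/std/params summary from `cv_results_`. It uses a synthetic dataset and a decision tree so that it runs on its own; `X_demo`, `y_demo`, `demo_search` and `summary` are illustrative names, not part of this notebook's pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Synthetic three-class problem (illustration only).\n",
    "X_demo, y_demo = make_classification(n_samples=300, n_features=8, n_informative=5,\n",
    "                                     n_classes=3, random_state=0)\n",
    "\n",
    "demo_search = GridSearchCV(DecisionTreeClassifier(random_state=0),\n",
    "                           param_grid={'max_depth': [2, 4, 6]},\n",
    "                           scoring='neg_log_loss', cv=3)\n",
    "demo_search.fit(X_demo, y_demo)\n",
    "\n",
    "# cv_results_ is a dict of arrays; a DataFrame gives one row per parameter setting.\n",
    "summary = pd.DataFrame(demo_search.cv_results_)[['params', 'mean_test_score', 'std_test_score']]\n",
    "print(summary)\n",
    "print(demo_search.best_params_, demo_search.best_score_)"
   ]
  },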
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tune min_child_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'min_child_weight': range(1, 6, 2)}"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# min_child_weight = 1/sqrt(ratio_rare_event) = 5.5\n",
    "min_child_weight = range(1,6,2)\n",
    "param_test2_2 = dict(min_child_weight=min_child_weight)\n",
    "param_test2_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.59955, std: 0.00316, params: {'min_child_weight': 1},\n",
       "  mean: -0.59892, std: 0.00208, params: {'min_child_weight': 3},\n",
       "  mean: -0.59907, std: 0.00199, params: {'min_child_weight': 5}],\n",
       " {'min_child_weight': 3},\n",
       " -0.5989249048719666)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgb2_2 = XGBClassifier(\n",
    "        learning_rate=0.1,\n",
    "        n_estimators=193,  # best n_estimators from the first round\n",
    "        max_depth=6,  # max_depth found above\n",
    "        min_child_weight=1,\n",
    "        gamma=0,\n",
    "        subsample=0.6,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective='multi:softprob',\n",
    "        seed=3\n",
    "        )\n",
    "\n",
    "gsearch2_2 = GridSearchCV(xgb2_2, param_grid = param_test2_2, scoring='neg_log_loss',n_jobs=-1, cv=3)\n",
    "gsearch2_2.fit(train_x , train_y)\n",
    "\n",
    "gsearch2_2.grid_scores_, gsearch2_2.best_params_,     gsearch2_2.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tune subsample and colsample_bytree (colsample_bylevel stays fixed at 0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'colsample_bytree': [0.6, 0.7, 0.8, 0.9],\n",
       " 'subsample': [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subsample = [i/10.0 for i in range(3,9)]\n",
    "colsample_bytree = [i/10.0 for i in range(6,10)]\n",
    "param_test2_3 = dict(subsample=subsample, colsample_bytree=colsample_bytree)\n",
    "param_test2_3\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.60541, std: 0.00299, params: {'colsample_bytree': 0.6, 'subsample': 0.3},\n",
       "  mean: -0.60303, std: 0.00199, params: {'colsample_bytree': 0.6, 'subsample': 0.4},\n",
       "  mean: -0.60063, std: 0.00233, params: {'colsample_bytree': 0.6, 'subsample': 0.5},\n",
       "  mean: -0.59731, std: 0.00258, params: {'colsample_bytree': 0.6, 'subsample': 0.6},\n",
       "  mean: -0.59659, std: 0.00248, params: {'colsample_bytree': 0.6, 'subsample': 0.7},\n",
       "  mean: -0.59579, std: 0.00288, params: {'colsample_bytree': 0.6, 'subsample': 0.8},\n",
       "  mean: -0.60649, std: 0.00259, params: {'colsample_bytree': 0.7, 'subsample': 0.3},\n",
       "  mean: -0.60325, std: 0.00318, params: {'colsample_bytree': 0.7, 'subsample': 0.4},\n",
       "  mean: -0.59889, std: 0.00183, params: {'colsample_bytree': 0.7, 'subsample': 0.5},\n",
       "  mean: -0.59881, std: 0.00411, params: {'colsample_bytree': 0.7, 'subsample': 0.6},\n",
       "  mean: -0.59771, std: 0.00246, params: {'colsample_bytree': 0.7, 'subsample': 0.7},\n",
       "  mean: -0.59684, std: 0.00276, params: {'colsample_bytree': 0.7, 'subsample': 0.8},\n",
       "  mean: -0.60695, std: 0.00239, params: {'colsample_bytree': 0.8, 'subsample': 0.3},\n",
       "  mean: -0.60432, std: 0.00546, params: {'colsample_bytree': 0.8, 'subsample': 0.4},\n",
       "  mean: -0.60066, std: 0.00342, params: {'colsample_bytree': 0.8, 'subsample': 0.5},\n",
       "  mean: -0.59892, std: 0.00208, params: {'colsample_bytree': 0.8, 'subsample': 0.6},\n",
       "  mean: -0.59578, std: 0.00301, params: {'colsample_bytree': 0.8, 'subsample': 0.7},\n",
       "  mean: -0.59741, std: 0.00286, params: {'colsample_bytree': 0.8, 'subsample': 0.8},\n",
       "  mean: -0.60700, std: 0.00181, params: {'colsample_bytree': 0.9, 'subsample': 0.3},\n",
       "  mean: -0.60331, std: 0.00375, params: {'colsample_bytree': 0.9, 'subsample': 0.4},\n",
       "  mean: -0.60099, std: 0.00189, params: {'colsample_bytree': 0.9, 'subsample': 0.5},\n",
       "  mean: -0.59989, std: 0.00270, params: {'colsample_bytree': 0.9, 'subsample': 0.6},\n",
       "  mean: -0.59622, std: 0.00347, params: {'colsample_bytree': 0.9, 'subsample': 0.7},\n",
       "  mean: -0.59605, std: 0.00189, params: {'colsample_bytree': 0.9, 'subsample': 0.8}],\n",
       " {'colsample_bytree': 0.8, 'subsample': 0.7},\n",
       " -0.5957767634924883)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgb2_3 = XGBClassifier(\n",
    "        learning_rate=0.1,\n",
    "        n_estimators=193,  # best n_estimators from the first round\n",
    "        max_depth=6,  # max_depth found above\n",
    "        min_child_weight=3,\n",
    "        gamma=0,\n",
    "        subsample=0.6,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective='multi:softprob',\n",
    "        seed=3\n",
    "        )\n",
    "\n",
    "gsearch2_3 = GridSearchCV(xgb2_3, param_grid = param_test2_3, scoring='neg_log_loss',n_jobs=-1, cv=3)\n",
    "gsearch2_3.fit(train_x , train_y)\n",
    "\n",
    "gsearch2_3.grid_scores_, gsearch2_3.best_params_,     gsearch2_3.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
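    "## Step 3: Tune the regularization parameters\n",
    "There are two regularization parameters, reg_alpha and reg_lambda; to save time, they are again tuned one at a time."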
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tune reg_alpha first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'reg_alpha': [0, 0.1, 1, 2]}"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reg_alpha = [0, 0.1, 1, 2]    # default = 0; test 0, 0.1, 1, 2\n",
    "\n",
    "param_test3_1 = dict(reg_alpha=reg_alpha)\n",
    "param_test3_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.59578, std: 0.00301, params: {'reg_alpha': 0},\n",
       "  mean: -0.59579, std: 0.00194, params: {'reg_alpha': 0.1},\n",
       "  mean: -0.59642, std: 0.00264, params: {'reg_alpha': 1},\n",
       "  mean: -0.59755, std: 0.00326, params: {'reg_alpha': 2}],\n",
       " {'reg_alpha': 0},\n",
       " -0.5957767634924883)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgb3_1 = XGBClassifier(\n",
    "        learning_rate=0.1,\n",
    "        n_estimators=193,  # best n_estimators from the first round\n",
    "        max_depth=6,  # max_depth found above\n",
    "        min_child_weight=3,\n",
    "        gamma=0,\n",
    "        subsample=0.7,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective='multi:softprob',\n",
    "        seed=3\n",
    "        )\n",
    "\n",
    "gsearch3_1 = GridSearchCV(xgb3_1, param_grid = param_test3_1, scoring='neg_log_loss',n_jobs=-1, cv=3)\n",
    "gsearch3_1.fit(train_x , train_y)\n",
    "\n",
    "gsearch3_1.grid_scores_, gsearch3_1.best_params_,     gsearch3_1.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The best reg_alpha is the default value, 0."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tune reg_lambda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'reg_lambda': [0.1, 0.5, 1, 2]}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reg_lambda = [0.1, 0.5, 1, 2]      # default = 1; test 0.1, 0.5, 1, 2\n",
    "\n",
    "param_test3_2 = dict(reg_lambda=reg_lambda)\n",
    "param_test3_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\sklearn\\model_selection\\_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.59734, std: 0.00228, params: {'reg_lambda': 0.1},\n",
       "  mean: -0.59682, std: 0.00220, params: {'reg_lambda': 0.5},\n",
       "  mean: -0.59578, std: 0.00301, params: {'reg_lambda': 1},\n",
       "  mean: -0.59708, std: 0.00196, params: {'reg_lambda': 2}],\n",
       " {'reg_lambda': 1},\n",
       " -0.5957767634924883)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgb3_2 = XGBClassifier(\n",
    "        learning_rate=0.1,\n",
    "        n_estimators=193,  # best n_estimators from the first round\n",
    "        max_depth=6,  # max_depth found above\n",
    "        min_child_weight=3,\n",
    "        gamma=0,\n",
    "        subsample=0.7,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective='multi:softprob',\n",
    "        seed=3\n",
    "        )\n",
    "\n",
    "gsearch3_2 = GridSearchCV(xgb3_2, param_grid = param_test3_2, scoring='neg_log_loss',n_jobs=-1, cv=3)\n",
    "gsearch3_2.fit(train_x , train_y)\n",
    "\n",
    "gsearch3_2.grid_scores_, gsearch3_2.best_params_,     gsearch3_2.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The best reg_lambda is the default value, 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>bathrooms</th>\n",
       "      <th>bedrooms</th>\n",
       "      <th>price</th>\n",
       "      <th>price_bathrooms</th>\n",
       "      <th>price_bedrooms</th>\n",
       "      <th>room_diff</th>\n",
       "      <th>room_num</th>\n",
       "      <th>Year</th>\n",
       "      <th>Month</th>\n",
       "      <th>Day</th>\n",
       "      <th>...</th>\n",
       "      <th>walk</th>\n",
       "      <th>walls</th>\n",
       "      <th>war</th>\n",
       "      <th>washer</th>\n",
       "      <th>water</th>\n",
       "      <th>wheelchair</th>\n",
       "      <th>wifi</th>\n",
       "      <th>windows</th>\n",
       "      <th>work</th>\n",
       "      <th>interest_level</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>3795</td>\n",
       "      <td>1897.500000</td>\n",
       "      <td>1897.500000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>28</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2.0</td>\n",
       "      <td>3</td>\n",
       "      <td>5500</td>\n",
       "      <td>1833.333333</td>\n",
       "      <td>1375.000000</td>\n",
       "      <td>-1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>4</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>3100</td>\n",
       "      <td>1550.000000</td>\n",
       "      <td>1033.333333</td>\n",
       "      <td>-1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>3</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>3750</td>\n",
       "      <td>1875.000000</td>\n",
       "      <td>1875.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>5</td>\n",
       "      <td>21</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "      <td>7500</td>\n",
       "      <td>2500.000000</td>\n",
       "      <td>2500.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>4</td>\n",
       "      <td>30</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 228 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   bathrooms  bedrooms  price  price_bathrooms  price_bedrooms  room_diff  \\\n",
       "0        1.0         1   3795      1897.500000     1897.500000        0.0   \n",
       "1        2.0         3   5500      1833.333333     1375.000000       -1.0   \n",
       "2        1.0         2   3100      1550.000000     1033.333333       -1.0   \n",
       "3        1.0         1   3750      1875.000000     1875.000000        0.0   \n",
       "4        2.0         2   7500      2500.000000     2500.000000        0.0   \n",
       "\n",
       "   room_num  Year  Month  Day       ...        walk  walls  war  washer  \\\n",
       "0       2.0  2016      6   28       ...           0      0    0       0   \n",
       "1       5.0  2016      6    4       ...           0      0    0       0   \n",
       "2       3.0  2016      6    3       ...           0      0    1       0   \n",
       "3       2.0  2016      5   21       ...           0      0    0       0   \n",
       "4       4.0  2016      4   30       ...           0      0    0       0   \n",
       "\n",
       "   water  wheelchair  wifi  windows  work  interest_level  \n",
       "0      0           0     0        0     0               1  \n",
       "1      0           0     0        0     0               2  \n",
       "2      0           0     0        0     0               1  \n",
       "3      0           0     0        0     0               2  \n",
       "4      0           0     0        0     0               2  \n",
       "\n",
       "[5 rows x 228 columns]"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 保存训练集和测试集\n",
    "train_data = pd.concat([train_x,train_y], axis=1, ignore_index=False)\n",
    "train_data = train_data.reset_index(drop=True)\n",
    "train_data.to_csv('train_data.csv',index=False)\n",
    "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>bathrooms</th>\n",
       "      <th>bedrooms</th>\n",
       "      <th>price</th>\n",
       "      <th>price_bathrooms</th>\n",
       "      <th>price_bedrooms</th>\n",
       "      <th>room_diff</th>\n",
       "      <th>room_num</th>\n",
       "      <th>Year</th>\n",
       "      <th>Month</th>\n",
       "      <th>Day</th>\n",
       "      <th>...</th>\n",
       "      <th>walk</th>\n",
       "      <th>walls</th>\n",
       "      <th>war</th>\n",
       "      <th>washer</th>\n",
       "      <th>water</th>\n",
       "      <th>wheelchair</th>\n",
       "      <th>wifi</th>\n",
       "      <th>windows</th>\n",
       "      <th>work</th>\n",
       "      <th>interest_level</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>3100</td>\n",
       "      <td>1550.0</td>\n",
       "      <td>1033.333333</td>\n",
       "      <td>-1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "      <td>6000</td>\n",
       "      <td>2000.0</td>\n",
       "      <td>2000.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2400</td>\n",
       "      <td>1200.0</td>\n",
       "      <td>2400.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>28</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>2825</td>\n",
       "      <td>1412.5</td>\n",
       "      <td>941.666667</td>\n",
       "      <td>-1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>4</td>\n",
       "      <td>16</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "      <td>2700</td>\n",
       "      <td>1350.0</td>\n",
       "      <td>900.000000</td>\n",
       "      <td>-1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2016</td>\n",
       "      <td>6</td>\n",
       "      <td>3</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 228 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   bathrooms  bedrooms  price  price_bathrooms  price_bedrooms  room_diff  \\\n",
       "0        1.0         2   3100           1550.0     1033.333333       -1.0   \n",
       "1        2.0         2   6000           2000.0     2000.000000        0.0   \n",
       "2        1.0         0   2400           1200.0     2400.000000        1.0   \n",
       "3        1.0         2   2825           1412.5      941.666667       -1.0   \n",
       "4        1.0         2   2700           1350.0      900.000000       -1.0   \n",
       "\n",
       "   room_num  Year  Month  Day       ...        walk  walls  war  washer  \\\n",
       "0       3.0  2016      4    6       ...           0      0    0       0   \n",
       "1       4.0  2016      5    3       ...           0      0    0       0   \n",
       "2       1.0  2016      6   28       ...           0      0    0       0   \n",
       "3       3.0  2016      4   16       ...           0      0    0       0   \n",
       "4       3.0  2016      6    3       ...           0      0    0       0   \n",
       "\n",
       "   water  wheelchair  wifi  windows  work  interest_level  \n",
       "0      0           0     0        0     0               1  \n",
       "1      0           0     0        0     0               2  \n",
       "2      0           0     0        0     0               2  \n",
       "3      0           0     0        0     0               1  \n",
       "4      0           0     0        0     0               2  \n",
       "\n",
       "[5 rows x 228 columns]"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data = pd.concat([X_val,y_val], axis=1, ignore_index=False)\n",
    "test_data = test_data.reset_index(drop=True)\n",
    "test_data.to_csv('test_data.csv',index=False)\n",
    "test_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
